'''Generate Cython PYX wrappers for Boost stats distributions.'''

from typing import NamedTuple
from warnings import warn
from textwrap import dedent
from shutil import copyfile
import pathlib
import os
import argparse

from gen_func_defs_pxd import (  # type: ignore
    _gen_func_defs_pxd)
from _info import (  # type: ignore
    _x_funcs, _no_x_funcs, _klass_mapper)

# Description of one ufunc emitted by `_ufunc_gen`:
#   ufunc_name      -- Python-visible name, ``_{scipy_dist}_{method}``
#   num_inputs      -- number of ufunc inputs: the constructor args, plus one
#                      more for methods that also take ``x``
#   boost_func_name -- base name of the ``boost_<name><nargs>`` wrapper that
#                      the generated file cimports from func_defs
_MethodDef = NamedTuple('_MethodDef', [
    ('ufunc_name', str),
    ('num_inputs', int),
    ('boost_func_name', str),
])


def _ufunc_gen(scipy_dist: str, types: list, ctor_args: tuple,
               filename: str, boost_dist: str, x_funcs: list,
               no_x_funcs: list, distutils_build: bool):
    '''Write a Cython PYX file exposing one Boost distribution as ufuncs.

    We need methods defined for each rv_continuous/_discrete internal method:
        i.e.: _pdf, _cdf, etc.
    Some of these methods take constructor arguments and 1 extra argument,
        e.g.: _pdf(x, *ctor_args), _ppf(q, *ctor_args)
    while some of the methods take only constructor arguments:
        e.g.: _stats(*ctor_args)

    One ufunc named ``_{scipy_dist}_{method}`` is generated per method, with
    one inner loop per entry of `types`.

    Parameters
    ----------
    scipy_dist : str
        SciPy distribution name, used in ufunc names and docstrings.
    types : list
        NumPy type enum names to generate loops for. Each must be one of
        'NPY_FLOAT16', 'NPY_FLOAT', 'NPY_DOUBLE', 'NPY_LONGDOUBLE'.
    ctor_args : tuple
        Names of the distribution's constructor arguments.
    filename : str
        Path of the PYX file to (over)write.
    boost_dist : str
        Boost distribution class name, e.g. 'beta_distribution'. The part
        before '_distribution' selects the Boost header to include.
    x_funcs : list
        Method names that take an extra leading ``x`` argument.
    no_x_funcs : list
        Method names that take only the constructor arguments.
    distutils_build : bool
        If True, emit plain ``from xxx cimport``. Otherwise emit
        package-relative ``from .xxx cimport``.

    Raises
    ------
    ValueError
        If some ufunc would have zero inputs.
    KeyError
        If `types` contains an unsupported type name. This is raised before
        `filename` is opened.
    '''
    num_ctor_args = len(ctor_args)
    methods = [_MethodDef(
        ufunc_name=f'_{scipy_dist}_{x_func}',
        num_inputs=num_ctor_args+1,  # +1 for the x argument
        # PDF for the beta distribution has a custom wrapper:
        boost_func_name=('pdf_beta'
                         if boost_dist == 'beta_distribution'
                         and x_func == 'pdf' else x_func),
    ) for x_func in x_funcs]
    methods += [_MethodDef(
        ufunc_name=f'_{scipy_dist}_{func}',
        num_inputs=num_ctor_args,
        boost_func_name=func,
    ) for func in no_x_funcs]

    # Identify potential ufunc issues:
    no_input_methods = [m for m in methods if m.num_inputs == 0]
    if no_input_methods:
        raise ValueError("ufuncs must have >0 arguments! "
                         f"Cannot construct these ufuncs: {no_input_methods}")

    # C type spelled in the generated Cython for each supported NumPy type.
    # 'longdouble' is ctypedef'd in the generated file below. NPY_FLOAT16 is
    # mapped, but it only triggers a "not supported" warning.
    npy_to_ctype = {
        'NPY_LONGDOUBLE': 'longdouble',
        'NPY_DOUBLE': 'double',
        'NPY_FLOAT': 'float',
        'NPY_FLOAT16': 'npy_half',
    }
    # Resolve every type up front. An unsupported name then raises KeyError
    # before `filename` is truncated, instead of after a partial write.
    c_types = [npy_to_ctype[T] for T in types]

    boost_hdr_name = boost_dist.split('_distribution')[0]
    # Sorted so the generated file does not depend on set iteration order.
    unique_num_inputs = sorted({m.num_inputs for m in methods})
    has_NPY_FLOAT16 = 'NPY_FLOAT16' in types
    has_NPY_LONGDOUBLE = 'NPY_LONGDOUBLE' in types
    # Separator placing each joined item on a new line at the 16-column
    # indent used inside the templates below.
    line_joiner = ',\n    ' + ' '*12
    num_types = len(types)
    loop_fun = 'PyUFunc_T'
    func_defs_cimports = line_joiner.join(
        f"boost_{m.boost_func_name}{num_ctor_args}" for m in methods)
    # Same separator minus the comma: one ctypedef per line.
    nontype_params = line_joiner[1:].join(
        f'ctypedef int NINPUTS{n} "{n}"' for n in unique_num_inputs)

    boost_hdr = f'boost/math/distributions/{boost_hdr_name}.hpp'
    # There's no __init__.py in a distutils build, so no `from .xxx cimport`
    relimport = '' if distutils_build else '.'

    with open(filename, 'w') as fp:
        fp.write(dedent(f'''\
            # distutils: language = c++
            # cython: language_level=3

            # This file was generated by stats/_boost/include/code_gen.py
            # All modifications to this file will be overwritten.

            from numpy cimport (
                import_array,
                import_ufunc,
                PyUFunc_FromFuncAndData,
                PyUFuncGenericFunction,
                PyUFunc_None,
                {line_joiner.join(types)}
            )
            from {relimport}templated_pyufunc cimport PyUFunc_T
            from {relimport}func_defs cimport (
                {func_defs_cimports},
            )
            cdef extern from "{boost_hdr}" namespace "boost::math" nogil:
                cdef cppclass {boost_dist} nogil:
                    pass

            # Workaround for Cython's lack of non-type template parameter
            # support
            cdef extern from * nogil:
                {nontype_params}

            _DUMMY = ""
            import_array()
            import_ufunc()
            '''))

        if has_NPY_LONGDOUBLE:
            fp.write('ctypedef long double longdouble\n\n')
        if has_NPY_FLOAT16:
            warn('Boost stats NPY_FLOAT16 ufunc generation is not '
                 'currently supported!')

        # Generate ufuncs for each method
        for ii, m in enumerate(methods):
            N = m.num_inputs
            nargs = N + 1  # inputs plus the single output arg
            fp.write(dedent(f'''
                cdef PyUFuncGenericFunction loop_func{ii}[{num_types}]
                cdef void* func{ii}[1*{num_types}]
                cdef char types{ii}[{nargs}*{num_types}]
                '''))

            boost_fun = f'boost_{m.boost_func_name}{num_ctor_args}'
            for jj, (T, ctype) in enumerate(zip(types, c_types)):
                type_str = ", ".join([ctype]*(1+num_ctor_args))
                boost_tmpl = f'{boost_dist}, {type_str}'
                fp.write(f'''\
loop_func{ii}[{jj}] = <PyUFuncGenericFunction>{loop_fun}[{ctype}, NINPUTS{N}]
func{ii}[{jj}] = <void*>{boost_fun}[{boost_tmpl}]
''')
                for tidx in range(nargs):
                    fp.write(f'types{ii}[{tidx}+{jj}*{nargs}] = {T}\n')

            # Methods taking `x` have one input beyond the ctor args. This is
            # decided by input count, not `m.boost_func_name in x_funcs`,
            # because renamed wrappers such as beta's 'pdf_beta' are not in
            # `x_funcs`. Joining a list also avoids a dangling ", " when
            # there are no ctor args.
            arg_names = (['x'] if N > num_ctor_args else []) + list(ctor_args)
            arg_list_str = ', '.join(arg_names)
            fp.write(dedent(f'''
                {m.ufunc_name} = PyUFunc_FromFuncAndData(
                    loop_func{ii},
                    func{ii},
                    types{ii},
                    {num_types},  # number of supported input types
                    {m.num_inputs},  # number of input args
                    1,  # number of output args
                    PyUFunc_None,  # `identity` element, never mind this
                    "{m.ufunc_name}",  # function name
                    ("{m.ufunc_name}({arg_list_str}) -> computes "
                     "{m.boost_func_name} of {scipy_dist} distribution"),
                    0  # unused
                )
                '''))


if __name__ == '__main__':
    def _str_to_bool(value):
        '''Parse a command-line boolean such as "True", "false", "1" or "0".

        ``argparse``'s ``type=bool`` would treat every non-empty string,
        including "False" and "0", as True.

        Raises
        ------
        argparse.ArgumentTypeError
            If `value` is not a recognised boolean spelling.
        '''
        normalized = str(value).strip().lower()
        if normalized in {'1', 'true', 'yes', 'y', 'on'}:
            return True
        if normalized in {'0', 'false', 'no', 'n', 'off', ''}:
            return False
        raise argparse.ArgumentTypeError(
            f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument("-o", "--outdir", type=str,
                        help="Path to the output directory")
    # Accepts `--distutils-build` (-> True) as well as explicit
    # `--distutils-build True` / `--distutils-build False`.
    parser.add_argument("--distutils-build", type=_str_to_bool, nargs='?',
                        const=True, default=False,
                        help="Whether or not this is a distutils build")
    args = parser.parse_args()

    # This script lives in <_boost>/include/, so go up two levels.
    _boost_dir = pathlib.Path(__file__).resolve().parent.parent
    if args.outdir:
        src_dir = pathlib.Path(args.outdir)
    else:
        # We're using setup.py here, not Meson. Use the in-tree src directory
        src_dir = _boost_dir / 'src'
    # Ensure the target directory exists (no-op if it already does), so the
    # copy and the generated writes below cannot fail on a missing directory.
    src_dir.mkdir(exist_ok=True, parents=True)

    # copy contents of include into directory to satisfy Cython
    # PXD include conditions
    inc_dir = _boost_dir / 'include'
    src = 'templated_pyufunc.pxd'
    copyfile(inc_dir / src, src_dir / src)

    # generate the PXD and PYX wrappers
    _gen_func_defs_pxd(
        f'{src_dir}/func_defs.pxd',
        x_funcs=_x_funcs,
        no_x_funcs=_no_x_funcs)
    for b, s in _klass_mapper.items():
        _ufunc_gen(
            scipy_dist=s.scipy_name,
            types=['NPY_FLOAT', 'NPY_DOUBLE', 'NPY_LONGDOUBLE'],
            ctor_args=s.ctor_args,
            filename=f'{src_dir}/{s.scipy_name}_ufunc.pyx',
            boost_dist=f'{b}_distribution',
            x_funcs=_x_funcs,
            no_x_funcs=_no_x_funcs,
            distutils_build=args.distutils_build,
        )
